{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Zero Shot NER\n",
    "\n",
    "\n",
    "![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/nlu/blob/master/examples/colab/healthcare/medical_named_entity_recognition/zero_shot_ner.ipynb)\n",
    "\n",
    "Based on the John Snow Labs Enterprise-NLP [ZeroShotNerModel](https://nlp.johnsnowlabs.com/docs/en/licensed_annotators#zeroshotnermodel) annotator.  \n",
    "Its architecture is based on `RoBertaForQuestionAnswering`.\n",
    "Zero-shot models excel at generalization: they can accurately predict entities in very different datasets without fine-tuning or training from scratch for each new domain.\n",
    "A model trained to solve one specific problem can achieve better accuracy than a zero-shot model on that task,\n",
    "but it probably won’t be useful on a different one.\n",
    "That is where zero-shot models show their usefulness: they can achieve good results across various domains.\n"
   ],
   "metadata": {
    "id": "OkOiSHgdV1yK"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from johnsnowlabs import nlp\n",
    "\n",
    "# After uploading your license, run this to install all licensed Python wheels and pre-download the JARs for the Spark session JVM\n",
    "nlp.install()"
   ],
   "metadata": {
    "id": "2-zF0lmNP7DW"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Load the Model"
   ],
   "metadata": {
    "id": "4ork95BOWO6r"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "enterprise_zero_shot_ner = nlp.load('en.zero_shot.ner_roberta')\n",
    "enterprise_zero_shot_ner"
   ],
   "metadata": {
    "id": "CL3jxSQTdBgT",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "3f7cafa1-01a1-4639-a2cf-ce51949e41f4"
   },
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Warning::Spark Session already created, some configs may not take.\n",
      "Warning::Spark Session already created, some configs may not take.\n",
      "zero_shot_ner_roberta download started this may take some time.\n",
      "Approximate size to download 438.9 MB\n",
      "[OK!]\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'zero_shot_ner': ZeroShotNerModel_16f413aaabeb,\n",
       " 'document_assembler': DocumentAssembler_f75df92df716,\n",
       " 'tokenizer': Tokenizer_8458d8f323fc,\n",
       " 'chunk_converter_licensed@entities': NerConverterInternal_6695cdfd63eb}"
      ]
     },
     "metadata": {},
     "execution_count": 5
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Configure entity definitions"
   ],
   "metadata": {
    "id": "yLzY3sL3WQmU"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "enterprise_zero_shot_ner['zero_shot_ner'].setEntityDefinitions(\n",
    "    {\n",
    "        \"PROBLEM\": [\n",
    "            \"What is the disease?\",\n",
    "            \"What is his symptom?\",\n",
    "            \"What is her disease?\",\n",
    "            \"What is his disease?\",\n",
    "            \"What is the problem?\",\n",
    "            \"What does the patient suffer from?\",\n",
    "            \"What was the reason that the patient is admitted to the clinic?\",\n",
    "        ],\n",
    "        \"DRUG\": [\n",
    "            \"Which drug?\",\n",
    "            \"Which is the drug?\",\n",
    "            \"What is the drug?\",\n",
    "            \"Which drug does he use?\",\n",
    "            \"Which drug does she use?\",\n",
    "            \"Which drug do I use?\",\n",
    "            \"Which drug is prescribed for a symptom?\",\n",
    "        ],\n",
    "        \"ADMISSION_DATE\": [\"When was the patient admitted to a clinic?\"],\n",
    "        \"PATIENT_AGE\": [\n",
    "            \"How old is the patient?\",\n",
    "            \"What is the age of the patient?\",\n",
    "        ],\n",
    "    }\n",
    ")\n"
   ],
   "metadata": {
    "id": "aTeTPgdJ2lEF",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "8bc0f37c-11d7-47c6-9c62-7c97bfa87b3b"
   },
   "execution_count": 6,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "ZeroShotNerModel_16f413aaabeb"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ]
  },
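  {
   "cell_type": "markdown",
   "source": [
    "Each label is driven entirely by its list of questions, so a misspelled or malformed question can quietly weaken the results for that label.\n",
    "The cell below is a minimal sketch of a sanity check that could be run on the dictionary before it is passed to `setEntityDefinitions`.\n",
    "`check_entity_definitions` and `example_definitions` are hypothetical names introduced here for illustration and are not part of the John Snow Labs API; the rules applied (non-empty list, trailing `?`, no duplicates) are assumed heuristics, not requirements of the model.\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def check_entity_definitions(definitions: Dict[str, List[str]]) -> List[str]:\n",
    "    \"\"\"Return human-readable warnings for suspicious entity definitions (hypothetical helper).\"\"\"\n",
    "    warnings = []\n",
    "    for label, questions in definitions.items():\n",
    "        if not questions:\n",
    "            warnings.append(f\"{label}: no questions defined\")\n",
    "        seen = set()\n",
    "        for question in questions:\n",
    "            # Compare case-insensitively so near-identical prompts are caught\n",
    "            normalized = question.strip().lower()\n",
    "            if not normalized.endswith(\"?\"):\n",
    "                warnings.append(f\"{label}: question does not end with '?': {question!r}\")\n",
    "            if normalized in seen:\n",
    "                warnings.append(f\"{label}: duplicate question: {question!r}\")\n",
    "            seen.add(normalized)\n",
    "    return warnings\n",
    "\n",
    "\n",
    "# Deliberately flawed definitions, used only to exercise the checks\n",
    "example_definitions = {\n",
    "    \"DRUG\": [\"Which drug?\", \"which drug?\"],\n",
    "    \"PROBLEM\": [\"What does a patient suffer\"],\n",
    "    \"ADMISSION_DATE\": [],\n",
    "}\n",
    "\n",
    "for warning in check_entity_definitions(example_definitions):\n",
    "    print(warning)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },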
  {
   "cell_type": "markdown",
   "source": [
    "## Predict on some sample text with entities"
   ],
   "metadata": {
    "id": "ZKIfSDTRWS8W"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "\n",
    "df = enterprise_zero_shot_ner.predict(\n",
    "    [\n",
    "        \"The doctor pescribed Majezik for my severe headache.\",\n",
    "        \"The patient was admitted to the hospital for his colon cancer.\",\n",
    "        \"27 years old patient was admitted to clinic on Sep 1st by Dr. X for a right-sided pleural effusion for thoracentesis.\",\n",
    "    ]\n",
    ")\n",
    "\n",
    "df"
   ],
   "metadata": {
    "id": "yI1U1UvA2mzg",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 602
    },
    "outputId": "9f69949e-1100-46aa-9f9e-efef7ddbd9fd"
   },
   "execution_count": 7,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "sentence_detector_dl download started this may take some time.\n",
      "Approximate size to download 354.6 KB\n",
      "[OK!]\n",
      "Warning::Spark Session already created, some configs may not take.\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "                                            document  \\\n",
       "0  The doctor pescribed Majezik for my severe hea...   \n",
       "0  The doctor pescribed Majezik for my severe hea...   \n",
       "1  The patient was admitted to the hospital for h...   \n",
       "2  27 years old patient was admitted to clinic on...   \n",
       "2  27 years old patient was admitted to clinic on...   \n",
       "2  27 years old patient was admitted to clinic on...   \n",
       "\n",
       "                                 entities_zero_shot entities_zero_shot_class  \\\n",
       "0                                           Majezik                     DRUG   \n",
       "0                                   severe headache                  PROBLEM   \n",
       "1                                      colon cancer                  PROBLEM   \n",
       "2                                      27 years old              PATIENT_AGE   \n",
       "2                                           Sep 1st           ADMISSION_DATE   \n",
       "2  a right-sided pleural effusion for thoracentesis                  PROBLEM   \n",
       "\n",
       "  entities_zero_shot_confidence entities_zero_shot_origin_chunk  \\\n",
       "0                    0.64671576                               0   \n",
       "0                     0.5526346                               1   \n",
       "1                     0.8898498                               0   \n",
       "2                     0.6943085                               0   \n",
       "2                    0.95646095                               1   \n",
       "2                    0.50026613                               2   \n",
       "\n",
       "  entities_zero_shot_origin_sentence  \n",
       "0                                  0  \n",
       "0                                  0  \n",
       "1                                  0  \n",
       "2                                  0  \n",
       "2                                  0  \n",
       "2                                  0  "
       ],
       "text/html": [
        "\n",
        "  <div id=\"df-76287946-82eb-4b47-af04-5dfcbcac80fc\" class=\"colab-df-container\">\n",
        "    <div>\n",
        "<style scoped>\n",
        "    .dataframe tbody tr th:only-of-type {\n",
        "        vertical-align: middle;\n",
        "    }\n",
        "\n",
        "    .dataframe tbody tr th {\n",
        "        vertical-align: top;\n",
        "    }\n",
        "\n",
        "    .dataframe thead th {\n",
        "        text-align: right;\n",
        "    }\n",
        "</style>\n",
        "<table border=\"1\" class=\"dataframe\">\n",
        "  <thead>\n",
        "    <tr style=\"text-align: right;\">\n",
        "      <th></th>\n",
        "      <th>document</th>\n",
        "      <th>entities_zero_shot</th>\n",
        "      <th>entities_zero_shot_class</th>\n",
        "      <th>entities_zero_shot_confidence</th>\n",
        "      <th>entities_zero_shot_origin_chunk</th>\n",
        "      <th>entities_zero_shot_origin_sentence</th>\n",
        "    </tr>\n",
        "  </thead>\n",
        "  <tbody>\n",
        "    <tr>\n",
        "      <th>0</th>\n",
        "      <td>The doctor pescribed Majezik for my severe hea...</td>\n",
        "      <td>Majezik</td>\n",
        "      <td>DRUG</td>\n",
        "      <td>0.64671576</td>\n",
        "      <td>0</td>\n",
        "      <td>0</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>0</th>\n",
        "      <td>The doctor pescribed Majezik for my severe hea...</td>\n",
        "      <td>severe headache</td>\n",
        "      <td>PROBLEM</td>\n",
        "      <td>0.5526346</td>\n",
        "      <td>1</td>\n",
        "      <td>0</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>1</th>\n",
        "      <td>The patient was admitted to the hospital for h...</td>\n",
        "      <td>colon cancer</td>\n",
        "      <td>PROBLEM</td>\n",
        "      <td>0.8898498</td>\n",
        "      <td>0</td>\n",
        "      <td>0</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
        "      <td>27 years old patient was admitted to clinic on...</td>\n",
        "      <td>27 years old</td>\n",
        "      <td>PATIENT_AGE</td>\n",
        "      <td>0.6943085</td>\n",
        "      <td>0</td>\n",
        "      <td>0</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
        "      <td>27 years old patient was admitted to clinic on...</td>\n",
        "      <td>Sep 1st</td>\n",
        "      <td>ADMISSION_DATE</td>\n",
        "      <td>0.95646095</td>\n",
        "      <td>1</td>\n",
        "      <td>0</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
        "      <td>27 years old patient was admitted to clinic on...</td>\n",
        "      <td>a right-sided pleural effusion for thoracentesis</td>\n",
        "      <td>PROBLEM</td>\n",
        "      <td>0.50026613</td>\n",
        "      <td>2</td>\n",
        "      <td>0</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
        "</div>\n",
       "  </div>\n"
      ]
     },
     "metadata": {},
     "execution_count": 7
    }
   ]
  }
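  ,
  {
   "cell_type": "markdown",
   "source": [
    "Several of the predictions above sit close to 0.5 confidence, so a downstream step may want to keep only the more confident entities.\n",
    "The cell below is a minimal sketch of such a filter. It rebuilds a small table by hand from the values printed above instead of re-running the model, and the names `predictions`, `confident` and `MIN_CONFIDENCE`, as well as the 0.6 cutoff, are assumptions chosen for illustration.\n",
    "The confidence column is coerced with `pd.to_numeric` in case the values arrive as strings; on a live run the same two steps could presumably be applied to `df` directly.\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Values copied by hand from the prediction table above (not a fresh model run)\n",
    "predictions = pd.DataFrame(\n",
    "    {\n",
    "        \"entities_zero_shot\": [\n",
    "            \"Majezik\",\n",
    "            \"severe headache\",\n",
    "            \"colon cancer\",\n",
    "            \"27 years old\",\n",
    "            \"Sep 1st\",\n",
    "            \"a right-sided pleural effusion for thoracentesis\",\n",
    "        ],\n",
    "        \"entities_zero_shot_class\": [\n",
    "            \"DRUG\",\n",
    "            \"PROBLEM\",\n",
    "            \"PROBLEM\",\n",
    "            \"PATIENT_AGE\",\n",
    "            \"ADMISSION_DATE\",\n",
    "            \"PROBLEM\",\n",
    "        ],\n",
    "        \"entities_zero_shot_confidence\": [\n",
    "            \"0.64671576\",\n",
    "            \"0.5526346\",\n",
    "            \"0.8898498\",\n",
    "            \"0.6943085\",\n",
    "            \"0.95646095\",\n",
    "            \"0.50026613\",\n",
    "        ],\n",
    "    }\n",
    ")\n",
    "\n",
    "MIN_CONFIDENCE = 0.6  # assumed cutoff, tune for your own data\n",
    "\n",
    "# Coerce to float first so the comparison is numeric even if the column holds strings\n",
    "predictions[\"entities_zero_shot_confidence\"] = pd.to_numeric(\n",
    "    predictions[\"entities_zero_shot_confidence\"]\n",
    ")\n",
    "confident = predictions[predictions[\"entities_zero_shot_confidence\"] >= MIN_CONFIDENCE]\n",
    "\n",
    "print(confident)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }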
 ]
}
